package io.gitee.terralian.code.generator.service.db.impl;

import java.sql.Connection;
import java.util.List;

import io.gitee.terralian.code.generator.dao.entity.DataBase;
import io.gitee.terralian.code.generator.dao.service.DataBaseService;
import io.gitee.terralian.code.generator.framework.entity.Result;
import io.gitee.terralian.code.generator.service.db.RemoteDBService;
import io.gitee.terralian.code.generator.service.db.entity.ColumnRef;
import io.gitee.terralian.code.generator.service.db.entity.TableRef;
import cn.hutool.core.util.StrUtil;
import lombok.AllArgsConstructor;
import org.springframework.jdbc.core.BeanPropertyRowMapper;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.DriverManagerDataSource;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;

/**
 * {@link RemoteDBService} implementation that talks to stored data sources over plain JDBC.
 *
 * <p>Every call builds a fresh {@link DriverManagerDataSource} from the stored {@link DataBase}
 * settings (URL, username, password, driver class); nothing is cached between calls.
 *
 * <p>NOTE(review): the class name says "Remove" while the interface is {@code RemoteDBService};
 * renaming would change the default Spring bean name, so it is left as-is.
 */
@Service
@AllArgsConstructor
public class JDBCRemoveDBService implements RemoteDBService {

    private final DataBaseService dataBaseService;

    /**
     * Lists table names and comments of one schema; the single {@code ?} is the schema name.
     * Values are bound as JDBC parameters instead of being formatted into the SQL text, so a
     * schema name containing quotes cannot alter the statement.
     * NOTE(review): TABLE_COMMENT is a MySQL-style INFORMATION_SCHEMA column — confirm the
     * target databases are MySQL-compatible.
     */
    private static final String TABLE_SQL_TEMPLATE = "SELECT TABLE_NAME, TABLE_COMMENT FROM " +
            "INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ?";

    /**
     * Lists all column metadata of one table in ordinal order; parameters are (schema, table).
     */
    private static final String COLUMN_SQL_TEMPLATE = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS " +
            "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION";

    /**
     * Tests connectivity of the stored data source with the given id.
     *
     * @param id primary key of the stored {@link DataBase}
     * @return the result of {@link #connectTest(DataBase)}
     * @throws IllegalArgumentException if no data source exists for {@code id}
     */
    @Override
    public Result<String> connectTest(String id) {
        DataBase dataBase = dataBaseService.getById(id);
        Assert.notNull(dataBase, "数据源不存在");
        return connectTest(dataBase);
    }

    /**
     * Tries to open (and immediately close) one connection with the given settings.
     *
     * @param dataBase connection settings to test
     * @return a success result with message {@code "success"}, or a failure result carrying the
     *         error message of whatever went wrong (including an invalid driver class name)
     */
    @Override
    public Result<String> connectTest(DataBase dataBase) {
        // Build the data source inside the try: setDriverClassName throws for an unknown or blank
        // driver class, and a connectivity test should report that as a failure, not propagate it.
        try (Connection ignored = buildDriverManagerDataSource(dataBase).getConnection()) {
            return Result.successMsg("success");
        } catch (Exception e) {
            // Boundary method: every failure becomes a failed Result. getMessage() may be null,
            // so fall back to toString() to always give the caller something readable.
            String message = e.getMessage() != null ? e.getMessage() : e.toString();
            return Result.failMsg(message);
        }
    }

    /**
     * Lists the tables of the schema configured on the stored data source.
     *
     * @param id primary key of the stored {@link DataBase}
     * @return one {@link TableRef} per row of INFORMATION_SCHEMA.TABLES for that schema
     * @throws IllegalArgumentException if no data source exists for {@code id}
     */
    @Override
    public List<TableRef> getTables(String id) {
        DataBase dataBase = dataBaseService.getById(id);
        Assert.notNull(dataBase, "数据源不存在");

        JdbcTemplate jdbcTemplate = buildJdbcTemplate(dataBase);
        return jdbcTemplate.query(TABLE_SQL_TEMPLATE,
                BeanPropertyRowMapper.newInstance(TableRef.class),
                dataBase.getDbName());
    }

    /**
     * Lists the columns of a table in the stored data source with the given id.
     *
     * @param id        primary key of the stored {@link DataBase}
     * @param tableName table whose columns are listed
     * @return the result of {@link #getColumns(DataBase, String)}
     * @throws IllegalArgumentException if no data source exists for {@code id}
     */
    @Override
    public List<ColumnRef> getColumns(String id, String tableName) {
        DataBase dataBase = dataBaseService.getById(id);
        return getColumns(dataBase, tableName);
    }

    /**
     * Lists the columns of {@code tableName} in the schema configured on {@code dataBase}.
     *
     * @param dataBase  connection settings; must not be null
     * @param tableName table whose columns are listed; bound as a JDBC parameter
     * @return one {@link ColumnRef} per column, ordered by ORDINAL_POSITION
     * @throws IllegalArgumentException if {@code dataBase} is null
     */
    @Override
    public List<ColumnRef> getColumns(DataBase dataBase, String tableName) {
        Assert.notNull(dataBase, "数据源不存在");

        JdbcTemplate jdbcTemplate = buildJdbcTemplate(dataBase);
        return jdbcTemplate.query(COLUMN_SQL_TEMPLATE,
                BeanPropertyRowMapper.newInstance(ColumnRef.class),
                dataBase.getDbName(), tableName);
    }

    /** Wraps a freshly built data source for {@code dataBase} in a {@link JdbcTemplate}. */
    private JdbcTemplate buildJdbcTemplate(DataBase dataBase) {
        return new JdbcTemplate(buildDriverManagerDataSource(dataBase));
    }

    /**
     * Builds a {@link DriverManagerDataSource} from the stored settings.
     * NOTE(review): DriverManagerDataSource does not pool connections; acceptable for occasional
     * metadata queries, but confirm if call volume grows.
     */
    private DriverManagerDataSource buildDriverManagerDataSource(DataBase dataBase) {
        DriverManagerDataSource dataSource = new DriverManagerDataSource();
        dataSource.setUrl(dataBase.getUrl());
        dataSource.setUsername(dataBase.getUsername());
        dataSource.setPassword(dataBase.getPassword());
        dataSource.setDriverClassName(dataBase.getDriverClassName());
        return dataSource;
    }
}
